
import numpy as np
import gudhi as gd
import multipers as mp
from tqdm import tqdm
from itertools import product
from sklearn.neighbors import KernelDensity
from sklearn.base import BaseEstimator, TransformerMixin
from warnings import warn
from .signed_betti import *
from .invariants_with_persistable import *
from joblib import delayed
from types import FunctionType
from joblib import Parallel, delayed
from os.path import exists
from typing import Iterable
from torch import Tensor
import pandas as pd
from warnings import warn
import matplotlib.pyplot as plt
import MDAnalysis
from scipy.spatial import distance_matrix
from scipy.ndimage import gaussian_filter

def get_simplextree(x)->mp.SimplexTreeMulti:
	"""
	Materializes a SimplexTreeMulti from either a simplextree or a delayed call.

	Parameters
	----------
	x : an ``mp.SimplexTreeMulti`` (returned unchanged), or a length-3 sequence
		``(function, args, kwargs)`` whose first item is a plain Python function
		(``types.FunctionType``), such as the ``delayed(...)`` calls built by the
		transformers of this module.

	Returns
	-------
	The simplextree itself, or ``function(*args, **kwargs)``.

	Raises
	------
	TypeError if ``x`` is neither of the above.
	"""
	if isinstance(x, mp.SimplexTreeMulti):
		return x
	looks_like_delayed_call = len(x) == 3 and isinstance(x[0], FunctionType)
	if not looks_like_delayed_call:
		raise TypeError("Not a valid SimplexTree !")
	function, positional, keywords = x
	return function(*positional, **keywords)

def infer_grid_from_points(pts:Iterable[np.ndarray], num:int, strategy:str):
	"""
	Builds a per-coordinate filtration grid from a point cloud.

	Parameters
	----------
	pts : array-like of shape (n_points, n_coordinates).
	num : number of grid values per coordinate.
	strategy : "regular" -> ``num`` evenly spaced values between the
		per-coordinate minimum and maximum;
		"quantile" -> the per-coordinate quantiles at ``num`` evenly spaced
		levels of [0, 1].

	Returns
	-------
	np.ndarray of shape (n_coordinates, num); row ``i`` is the grid of coordinate ``i``.

	Raises
	------
	ValueError (a subclass of Exception) if ``strategy`` is unknown.
	"""
	if strategy == "regular":
		# named lower/upper to avoid shadowing the min/max builtins
		lower = np.min(pts, axis=0)
		upper = np.max(pts, axis=0)
		return np.linspace(lower, upper, num=num).T
	if strategy == "quantile":
		return np.quantile(pts, q=np.linspace(0, 1, num), axis=0).T
	raise ValueError(f"Grid strategy {strategy} not implemented")


# filtration_values = np.concatenate(Parallel(n_jobs=self.n_jobs)(delayed(get_filtration_values)(g,f=self.f) for g in X))
# 		self.range= tuple(np.quantile(filtration_values, [self.quantile, 1-self.quantile]))
################################################# Data2SimplexTree
class RipsDensity2SimplexTree(BaseEstimator, TransformerMixin):
	"""
	Transformer. Point clouds -> SimplexTreeMulti filtered by Rips distance and codensity.

	For each point cloud ``x``, the 1-skeleton of a gudhi Rips complex (edges up to
	``threshold``) is converted into an ``mp.SimplexTreeMulti``; parameter 1 is then
	filled with the lower-star of the codensity ``-score_samples(x)`` of a
	``KernelDensity`` fitted on ``x`` itself, and edges are collapsed.

	Parameters
	----------
	bandwidth, kernel : forwarded to ``sklearn.neighbors.KernelDensity``.
	threshold : ``max_edge_length`` of the Rips complex.
	sparse : forwarded to ``gd.RipsComplex(sparse=...)``.
	num_collapse : forwarded as ``num`` to both ``collapse_edges`` calls.
	num_parameters : number of parameters of the produced SimplexTreeMulti.
	delayed : if True, ``transform`` returns joblib delayed calls
		``(function, args, kwargs)`` instead of simplextrees (they can be
		materialized with ``get_simplextree``).
	rescale_density : if nonzero, the codensity is shifted to start at 0, scaled
		to [0, 1] (unless it is constant) and multiplied by ``rescale_density``.
	progress : show a tqdm progress bar.
	"""
	def __init__(self, bandwidth:float=1, threshold:float=np.inf, 
	sparse:float|None=None, num_collapse:int=0, 
	num_parameters:int=2, kernel:str="gaussian", delayed=False, rescale_density:float=0,
	progress:bool=False) -> None:
		super().__init__()
		self.bandwidth=bandwidth
		self.threshold = threshold
		self.sparse=sparse
		self.num_collapse=num_collapse
		self.num_parameters = num_parameters
		self.kernel = kernel
		self.delayed=delayed
		self.rescale_density = rescale_density
		self.progress=progress
		return
	def fit(self, X:np.ndarray|list, y=None):
		# Nothing to learn.
		return self

	
	def transform(self,X):
		"""
		Returns one SimplexTreeMulti per point cloud of ``X``, or one delayed call
		per point cloud when ``self.delayed`` is True.
		"""
		def get_st(x)->mp.SimplexTreeMulti:
			st = gd.RipsComplex(points = x, max_edge_length=self.threshold, sparse=self.sparse).create_simplex_tree(max_dimension=1)
			st = mp.SimplexTreeMulti(st, num_parameters = self.num_parameters)
			# One estimator per call: get_st runs concurrently in threads, and a shared
			# KernelDensity could be re-fitted by another thread between fit and score_samples.
			kde:KernelDensity = KernelDensity(bandwidth=self.bandwidth, kernel=self.kernel)
			# score_samples returns a log-density, hence -log(density) = codensity.
			codensity = -kde.fit(x).score_samples(x)
			if self.rescale_density != 0:
				codensity -= codensity.min()
				if codensity.max() != 0:	codensity /= codensity.max()
				codensity *= self.rescale_density
			st.fill_lowerstar(codensity, parameter = 1)
			# NOTE(review): two collapse passes (default, then strong=False up to dimension 1) — confirm intended.
			st.collapse_edges(num=self.num_collapse)
			st.collapse_edges(num=self.num_collapse, strong = False, max_dimension = 1) 
			return st
		# The bar is shown if and only if progress is requested.
		with tqdm(X, desc="Computing simplextrees", disable=not self.progress) as data:
			if self.delayed:
				return [delayed(get_st)(x) for x in data] # delay the computation for the to_module pipe, as simplextrees are not pickle-able.
			return Parallel(n_jobs=-1, prefer="threads")(delayed(get_st)(x) for x in data) # not picklable so prefer threads is necessary.

class SimplexTreeEdgeCollapser(BaseEstimator, TransformerMixin):
	"""
	Transformer. Applies ``collapse_edges`` to every simplextree of the input, in place.

	Parameters
	----------
	num_collapses : forwarded as ``num`` to ``collapse_edges``.
	full : forwarded as ``full`` to ``collapse_edges``.
	max_dimension : forwarded as ``max_dimension`` to ``collapse_edges`` when not
		None; when None the library default is used.
	"""
	def __init__(self, num_collapses:int=0, full:bool=False, max_dimension:int|None=None) -> None:
		super().__init__()
		self.full=full
		self.num_collapses=num_collapses
		self.max_dimension=max_dimension
		return
	def fit(self, X:np.ndarray|list, y=None):
		# Nothing to learn.
		return self
	def transform(self,X):
		"""Collapses the edges of each simplextree of ``X`` in place and returns ``X``."""
		collapse_kwargs = {"full": self.full, "num": self.num_collapses}
		# Only forward max_dimension when the user set it, so the default of
		# collapse_edges is kept otherwise.
		if self.max_dimension is not None:
			collapse_kwargs["max_dimension"] = self.max_dimension
		for simplextree in X:
			simplextree.collapse_edges(**collapse_kwargs)
		return X

#### MOLECULE DATA
# def _lines2bonds(_lines):
# 	out = []
# 	index = 0
# 	while index < len(_lines) and  _lines[index].strip() != "@<TRIPOS>BOND":
# 		index += 1
# 	index += 1
# 	while index < len(_lines) and  _lines[index].strip()[0] != "@":
# 		line = _lines[index].strip().split(" ")
# 		for j,truc in enumerate(line):
# 			line[j] = truc.strip()
# 		# try:
# 		out.append([int(stuff) for stuff in line if len(stuff) > 0])
# 		# except:
# 		# 	print_lin
# 		index +=1
# 	out = pd.DataFrame(out, columns=["bond_id","atom1", "atom2", "bond_type"])
# 	out.set_index(["bond_id"],inplace=True)
# 	return out
# def _get_mol2_file(path:str, num_cols:int=9, columns:dict|None=None):
# 	from biopandas.mol2 import split_multimol2,PandasMol2
# 	columns={
# 		0:('atom_id', int), 
# 		1:('atom_name', str),
# 		2:('x', float), 
# 		3:('y', float), 
# 		4:('z', float), 
# 		5:('atom_type', str), 
# 		6:('subst_id', int), 
# 		7:('subst_name', str), 
# 		8:('charge', float)
# 	} if columns is None else columns
# 	while len(columns) > num_cols:
# 		columns.pop(len(columns)-1)
# 	# try:
# 	molecules_dfs = []
# 	bonds_dfs = []
# 	for molecule in split_multimol2(path):
# 		_code, _lines = molecule
# 		try:
# 			bonds_dfs.append(_lines2bonds(_lines))
# 			molecule_df = PandasMol2().read_mol2_from_list(mol2_lines=_lines, mol2_code=_code, columns=columns).df
# 		except:
# 			print(_code)
# 			print(_lines)
# 		molecule_df.set_index(["atom_id"], inplace=True)
# 		molecules_dfs.append(molecule_df)        
# 	# except:
# 	#     return get_mol2_file(path=path, num_cols=num_cols-1)
# 	return molecules_dfs, bonds_dfs
# def _atom_to_mass(atom)->int:
# 	return ELEMENTS[atom].mass
# 	raise Exception(f" Atom {atom} has no registered mass.")

def _mol2st(path:str, bonds:bool = True, charge:bool=True, atomic_mass:bool=True, **kwargs):
	"""
	Builds a SimplexTreeMulti from a molecule file readable by ``MDAnalysis.Universe``.

	Vertices are the atoms, edges are the bonds. One filtration parameter is
	created per enabled option, in the order bonds, charge, atomic_mass:
	- bonds : edges get their bond length (``molecule.bonds.bonds()``), vertices 0;
	- charge : not implemented yet;
	- atomic_mass : lower-star filtration of the masses guessed from the atom names.

	Extra ``kwargs`` are accepted and ignored.

	Raises
	------
	NotImplementedError (a subclass of Exception) if ``charge`` is True.
	ValueError if no filtration is enabled.
	"""
	if charge:
		# Fail before parsing the file rather than after building the whole tree.
		raise NotImplementedError("Charge filtration is not implemented yet (TODO). Use charge=False.")
	num_filtrations = bonds + charge + atomic_mass
	if num_filtrations == 0:
		raise ValueError("At least one of bonds, charge or atomic_mass has to be True.")
	from MDAnalysis import Universe
	from MDAnalysis.topology.guessers import guess_masses
	molecule = Universe(path)
	# Vertices: one per atom, inserted as a (1, num_atoms) batch.
	nodes = molecule.atoms.indices.reshape(1,-1)
	# Edges: one column per bond; presumably (2, num_bonds) atom indices — TODO confirm dump_contents layout.
	edges = molecule.bonds.dump_contents().T
	num_vertices = nodes.shape[1]
	num_edges = edges.shape[1]

	st = mp.SimplexTreeMulti(num_parameters = num_filtrations)
	## Nodes filtration (all zeros; non-zero values come from fill_lowerstar below)
	nodes_filtrations = np.zeros((num_vertices,num_filtrations), dtype=np.float32)
	st.insert_batch(nodes, nodes_filtrations)

	## Edges filtration
	edges_filtration = np.zeros((num_edges, num_filtrations), dtype=np.float32)
	if bonds:
		edges_filtration[:,0] = molecule.bonds.bonds()
	st.insert_batch(edges, edges_filtration)
	if atomic_mass:
		masses = guess_masses(molecule.atoms.names)
		# The mass parameter comes right after the enabled bonds/charge parameters.
		st.fill_lowerstar(masses, parameter=bonds+charge)
	return st
class Molecule2SimplexTree(BaseEstimator, TransformerMixin):
	"""
	Transformer. Paths to molecule files -> SimplexTreeMulti, built by ``_mol2st``.

	Parameters
	----------
	charge_filtration, atomic_mass_filtration, bonds_filtration : forwarded to
		``_mol2st`` as ``charge``, ``atomic_mass``, ``bonds``.
		NOTE: the charge filtration is not implemented in ``_mol2st`` (it raises),
		so ``charge_filtration=False`` is currently required despite the default.
	delayed : if True, ``transform`` returns joblib delayed calls instead of simplextrees.
	progress : show a tqdm progress bar while computing the simplextrees.
	n_jobs : number of threads used by ``transform``.
	atom_columns, atom_num_columns, max_dimension : stored (scikit-learn parameters)
		but currently unused.
	"""
	def __init__(self, atom_columns:Iterable[str]|None=None, atom_num_columns:int=9, max_dimension:int|None=None, delayed:bool=False, progress:bool=False, charge_filtration:bool=True, atomic_mass_filtration:bool=True, bonds_filtration:bool=True,n_jobs:int=1) -> None:
		super().__init__()
		self.max_dimension=max_dimension
		self.delayed=delayed
		self.progress=progress
		self.charge_filtration = charge_filtration
		self.atomic_mass_filtration = atomic_mass_filtration
		self.bonds_filtration = bonds_filtration
		self.n_jobs = n_jobs
		self.atom_columns = atom_columns
		self.atom_num_columns = atom_num_columns
		return
	def fit(self, X:Iterable[str], y=None):
		# Nothing to learn; X is not inspected, so any iterable (even a generator) is accepted.
		return self
	def transform(self,X:Iterable[str]):
		"""Returns one SimplexTreeMulti (or delayed call) per molecule path of ``X``."""
		def to_simplex_tree(path_to_mol2_file:str):
			return _mol2st(path=path_to_mol2_file, 
				bonds=self.bonds_filtration,
				charge=self.charge_filtration,
				atomic_mass=self.atomic_mass_filtration,
			)
		if self.delayed:
			return [delayed(to_simplex_tree)(path) for path in X]
		paths = tqdm(X, desc="Computing simplextrees", disable=not self.progress)
		# threads: simplextrees are not picklable
		return Parallel(n_jobs=self.n_jobs, prefer="threads")(delayed(to_simplex_tree)(path) for path in paths)
		
############################################### Data2Signedmeasure

def tensor_möbius_inversion(tensor:Tensor|np.ndarray, grid_conversion:Iterable[np.ndarray]|None = None, plot:bool=False, raw:bool=False, num_parameters:int|None=None):
	betti_sparse = Tensor(tensor.copy()).to_sparse() # Copy necessary in some cases :(
	num_indices, num_pts = betti_sparse.indices().shape
	num_parameters = num_indices if num_parameters is None else num_parameters
	if num_indices == num_parameters: # either hilbert or rank invariant
		rank_invariant = False
	elif 2*num_parameters == num_indices:
		rank_invariant = True
	else:
		raise TypeError(f"Unsupported betti shape. {num_indices} has to be either {num_parameters} or {2*num_parameters}.")
	points_filtration = np.asarray(betti_sparse.indices().T, dtype=int)
	weights = np.asarray(betti_sparse.values(), dtype=int)

	if grid_conversion is not None:
		coords = np.empty(shape=(num_pts,num_indices), dtype=float)
		for i in range(num_indices):
			coords[:,i] = grid_conversion[i%num_parameters][points_filtration[:,i]]
	else:
		coords = points_filtration
	if (not rank_invariant) and plot:
		plt.scatter(points_filtration[:,0],points_filtration[:,1], c=weights)
		plt.colorbar()
	if (not rank_invariant) or raw: return coords, weights
	def _is_trivial(rectangle:np.array):
		birth=rectangle[:num_parameters]
		death=rectangle[num_parameters:]
		return np.all(birth<=death) # and not np.array_equal(birth,death)
	correct_indices = np.array([_is_trivial(rectangle) for rectangle in coords])
	if len(correct_indices) == 0:	return np.empty((0, num_indices)), np.empty((0))
	signed_measure = np.asarray(coords[correct_indices])
	weights = weights[correct_indices]
	if plot:
		assert signed_measure.shape[1] == 4 # plot only the rank decompo for the moment
		from matplotlib.pyplot import plot
		def _plot_rectangle(rectangle:np.ndarray, weight:float):
			x_axis=rectangle[[0,2]]
			y_axis=rectangle[[1,3]]
			color = "blue" if weight > 0 else "red"
			plot(x_axis, y_axis, c=color)
		for rectangle, weight in zip(signed_measure, weights):
			_plot_rectangle(rectangle=rectangle, weight=weight)
	return signed_measure, weights


# class DegreeRips2SignedMeasureH0(BaseEstimator, TransformerMixin):
# 	def __init__(self, min_rips_value, max_rips_value,max_normalized_degree, min_normalized_degree, grid_granularity, progress:bool=False, n_jobs=1) -> None:
# 		super().__init__()
# 		self.min_rips_value = min_rips_value
# 		self.max_rips_value = max_rips_value
# 		self.min_normalized_degree = min_normalized_degree
# 		self.max_normalized_degree = max_normalized_degree
# 		self.grid_granularity = grid_granularity
# 		self.progress=progress
# 		self.n_jobs = n_jobs
# 		return
# 	def fit(self, X:np.ndarray|list, y=None):
# 		return self
	
# 	def _transform1(self, data:np.ndarray):
# 		rips_values, normalized_degree_values, hilbert_function = hf_h0_degree_rips(
# 			data,
# 			min_rips_value = self.min_rips_value,
# 			max_rips_value = self.max_rips_value,
# 			min_normalized_degree = self.min_normalized_degree,
# 			max_normalized_degree = self.max_normalized_degree,
# 			grid_granularity = self.grid_granularity,
# 		)
# 		points, weights = betti_tensor2signed_measure(
# 			signed_betti(hilbert_function, threshold=True),
# 			grid_conversion=[rips_values, normalized_degree_values]
# 		)
# 		return [(points, weights)] # inside a list, for next pipelines. In general we may want multiple degrees -> multiple signed measure per data

# 	def transform(self,X):
# 		return Parallel(n_jobs=self.n_jobs)(delayed(DegreeRips2SignedMeasureH0._transform1)(self, data) 
# 		for data in tqdm(X, desc="Computing DegreeRips H0 signed measure.", disable = not self.progress))


class DegreeRips2SignedMeasure(BaseEstimator, TransformerMixin):
	"""
	Transformer. Point clouds -> one signed measure per requested homological degree,
	computed from ``hf_degree_rips`` on the pairwise distance matrix.

	Parameters
	----------
	degrees : homological degrees, each forwarded as ``homological_dimension``.
	min_rips_value, max_rips_value, min_normalized_degree, max_normalized_degree,
	grid_granularity : forwarded to ``hf_degree_rips``.
	progress : show a tqdm progress bar.
	n_jobs : joblib parallelism of ``transform``.
	sparse : if True, each measure is converted with ``tensor_möbius_inversion``
		into ``(points, weights)`` on the (rips, normalized degree) grid; otherwise
		the output of ``signed_betti`` is returned as is.
	"""
	def __init__(self, degrees:Iterable[int], min_rips_value:float, max_rips_value,max_normalized_degree:float, min_normalized_degree:float, grid_granularity:int, progress:bool=False, n_jobs=1, sparse:bool=False) -> None:
		super().__init__()
		assert 0 <= min_rips_value <= max_rips_value
		assert 0 <= min_normalized_degree <= max_normalized_degree
		assert grid_granularity > 0
		self.degrees = degrees
		self.min_rips_value = min_rips_value
		self.max_rips_value = max_rips_value
		self.min_normalized_degree = min_normalized_degree
		self.max_normalized_degree = max_normalized_degree
		self.grid_granularity = grid_granularity
		self.sparse = sparse
		self.progress = progress
		self.n_jobs = n_jobs

	def fit(self, X:np.ndarray|list, y=None):
		"""Nothing to learn; returns self."""
		return self

	def _signed_measure_of_degree(self, pairwise_distances:np.ndarray, degree:int):
		"""Signed measure of one homological degree, dense or sparse depending on ``self.sparse``."""
		rips_values, normalized_degree_values, hilbert_function = hf_degree_rips(
			pairwise_distances,
			min_rips_value=self.min_rips_value,
			max_rips_value=self.max_rips_value,
			min_normalized_degree=self.min_normalized_degree,
			max_normalized_degree=self.max_normalized_degree,
			grid_granularity=self.grid_granularity,
			homological_dimension=degree,
		)
		dense_measure = signed_betti(hilbert_function, threshold=True)
		if not self.sparse:
			return dense_measure
		return tensor_möbius_inversion(
			tensor=dense_measure,
			num_parameters=2,
			grid_conversion=[rips_values, normalized_degree_values],
		)

	def _transform1(self, data:np.ndarray):
		"""List of the signed measures of one point cloud, one per degree of ``self.degrees``."""
		pairwise_distances = distance_matrix(data, data)
		return [self._signed_measure_of_degree(pairwise_distances, degree) for degree in self.degrees]

	def transform(self,X):
		description = f"Computing DegreeRips, of degrees {self.degrees}, signed measures."
		jobs = (
			delayed(DegreeRips2SignedMeasure._transform1)(self, data)
			for data in tqdm(X, desc=description, disable=not self.progress)
		)
		return Parallel(n_jobs=self.n_jobs)(jobs)






################################################# SimplexTree2...

def _pts_convolution_sparse(pts:np.ndarray, pts_weights:np.ndarray, filtration_grid:Iterable[np.ndarray], kernel="gaussian", bandwidth=0.1, **more_kde_args):
	grid_iterator = np.asarray(list(product(*filtration_grid)))
	grid_shape = [len(f) for f in filtration_grid]
	if len(pts) == 0:
		warn("Found a trivial signed measure !")
		return np.zeros(shape=grid_shape)
	kde = KernelDensity(kernel=kernel, bandwidth=bandwidth, **more_kde_args)
	
	pos_indices = pts_weights>0
	neg_indices = pts_weights<0
	img_pos = kde.fit(pts[pos_indices], sample_weight=pts_weights[pos_indices]).score_samples(grid_iterator).reshape(grid_shape)
	img_neg = kde.fit(pts[neg_indices], sample_weight=-pts_weights[neg_indices]).score_samples(grid_iterator).reshape(grid_shape)
	return np.exp(img_pos) - np.exp(img_neg)



def _st2ranktensor(st:mp.SimplexTreeMulti, filtration_grid:np.ndarray, degree:int, plot:bool, reconvert_grid:bool, num_collapse:int|str=0):
	"""
	Rectangle decomposition of the rank invariant of ``st`` in homological ``degree``.

	A copy of ``st`` is snapped onto ``filtration_grid`` (filtration values replaced
	by grid coordinates) and edge-collapsed; its rank invariant is computed with
	``mp.rank_invariant2d``, decomposed with ``rank_decomposition_by_rectangles``
	and converted by ``tensor_möbius_inversion``.

	Parameters
	----------
	st : input simplextree (left untouched).
	filtration_grid : one array of values per parameter.
	degree : homological degree.
	plot : forwarded to ``tensor_möbius_inversion``.
	reconvert_grid : if True, rectangles are expressed in filtration values
		instead of grid indices.
	num_collapse : "full", or the number of edge collapses (up to dimension degree+1).

	Returns
	-------
	(rectangles, weights) as returned by ``tensor_möbius_inversion``.

	Raises
	------
	TypeError if ``num_collapse`` is neither "full" nor an int.
	"""
	# copy: grid_squeeze changes the filtration values in place
	squeezed = mp.SimplexTreeMulti(st)
	squeezed.grid_squeeze(filtration_grid=filtration_grid, coordinate_values=True)
	max_collapse_dimension = degree + 1
	if num_collapse == "full":
		squeezed.collapse_edges(full=True, ignore_warning=True, max_dimension=max_collapse_dimension)
	elif isinstance(num_collapse, int):
		squeezed.collapse_edges(num=num_collapse, ignore_warning=True, max_dimension=max_collapse_dimension)
	else:
		raise TypeError(f"Invalid num_collapse={num_collapse} type. Either full, or an integer.")
	grid_shape = [len(f) for f in filtration_grid]
	rank_tensor = mp.rank_invariant2d(squeezed, degree=degree, grid_shape=grid_shape)
	rank_decomposition = rank_decomposition_by_rectangles(rank_tensor, threshold=True)
	return tensor_möbius_inversion(
		tensor=rank_decomposition,
		grid_conversion=filtration_grid if reconvert_grid else None,
		plot=plot,
		num_parameters=st.num_parameters,
	)

class SimplexTree2RectangleDecomposition(BaseEstimator,TransformerMixin):
	"""
	Transformer. 2 parameter SimplexTrees to their respective rectangle decomposition. 

	For each simplextree, returns the list of ``_st2ranktensor`` outputs, one per
	homological degree of ``degrees``.

	Parameters
	----------
	filtration_grid : one array of values per parameter.
	degrees : homological degrees.
	plot, reconvert_grid : forwarded to ``_st2ranktensor``.
	num_collapses : forwarded as ``num_collapse`` to ``_st2ranktensor``.
	"""
	def __init__(self, filtration_grid:np.ndarray, degrees:Iterable[int], plot=False, reconvert_grid=True, num_collapses:int=0):
		super().__init__()
		self.filtration_grid = filtration_grid
		self.degrees = degrees
		self.plot = plot
		self.reconvert_grid = reconvert_grid
		self.num_collapses = num_collapses

	def fit(self, X, y=None):
		"""
		Nothing is learned: the grid is given at construction.
		TODO : infer grid from multiple simplextrees
		"""
		return self

	def _decompose(self, simplextree):
		"""Rectangle decompositions of one simplextree, one per degree."""
		return [
			_st2ranktensor(
				simplextree,
				filtration_grid=self.filtration_grid,
				degree=degree,
				plot=self.plot,
				reconvert_grid=self.reconvert_grid,
				num_collapse=self.num_collapses,
			)
			for degree in self.degrees
		]

	def transform(self,X:Iterable[mp.SimplexTreeMulti]):
		decompositions = []
		for simplextree in X:
			decompositions.append(self._decompose(simplextree))
		## TODO : return iterator ?
		return decompositions


def betti_matrix2signed_measure(betti:coo_array|np.ndarray, grid_conversion:Iterable[np.ndarray]|None = None):
	if isinstance(betti, np.ndarray):   betti = coo_array(betti)
	points_filtration = np.empty(shape=(betti.getnnz(),2), dtype=int) # coo matrix is only for matrices -> 2d
	points_filtration[:,0] = betti.row
	points_filtration[:,1] = betti.col
	weights = np.array(betti.data, dtype=int)
	if grid_conversion is not None:
		coords = np.empty(shape=(betti.getnnz(),2), dtype=float)
		for i in range(2):
			coords[:,i] = grid_conversion[i][points_filtration[:,i]]
	else:
		coords = points_filtration
	return coords, weights


class SimplexTree2SignedMeasure(BaseEstimator,TransformerMixin):
	"""
	Transformer. SimplexTreeMulti (or delayed calls producing them, see
	``get_simplextree``) -> signed measures of their 2-parameter Hilbert function,
	one per homological degree.

	Parameters
	----------
	degrees : homological degrees.
	filtration_grid : one array of values per parameter. If None, it is inferred in
		``fit`` from the filtration bounds of the input and ``resolution``.
	progress : tqdm progress bar in ``transform``.
	num_collapses : "full", or the number of edge collapses applied before the computation.
	n_jobs : joblib parallelism.
	resolution : number of grid values per parameter; a single int (same for every
		parameter) or one int per parameter. Only used when ``filtration_grid`` is None.
	sparse : if True, each signed measure is converted by ``tensor_möbius_inversion``
		into ``(points, weights)``; otherwise the output of
		``signed_betti(..., sparse=False)`` is returned.
	plot : forwarded to ``tensor_möbius_inversion``.
	filtration_quantile : forwarded as ``q`` to ``filtration_bounds`` when inferring the grid.
	"""
	def __init__(self, degrees:list[int],filtration_grid:Iterable[np.ndarray]=None, progress=False, num_collapses=100, n_jobs=1, resolution:Iterable[int]|int|None=None, sparse=False, plot:bool=False, filtration_quantile:float=0):
		super().__init__()
		self.degrees = degrees
		self.filtration_grid = filtration_grid
		self.progress = progress
		self.num_collapses=num_collapses
		self.n_jobs = n_jobs
		self.resolution = resolution
		self.plot=plot
		self.sparse=sparse
		self.filtration_quantile=filtration_quantile
		assert resolution is not None or filtration_grid is not None
		return
	def fit(self, X, y=None): # Todo : infer filtration grid ?
		"""
		Infers ``filtration_grid`` when it was not given: the box spanned by the
		``filtration_bounds(q=filtration_quantile)`` of every input simplextree is
		discretized with ``resolution`` evenly spaced values per parameter.
		"""
		if self.filtration_grid is None:
			print("Inferring filtration grid from simplextrees...", end="", flush=True)
			get_filtration_bounds = lambda x : get_simplextree(x).filtration_bounds(q=self.filtration_quantile)
			filtration_bounds =  Parallel(n_jobs=self.n_jobs, prefer="threads")(delayed(get_filtration_bounds)(x) for x in X)
			# presumably one (2, num_parameters) [min, max] array per simplextree;
			# box[0] / box[1] are the coordinate-wise min / max over all of them.
			box = np.array([np.min(filtration_bounds, axis=(0,1)), np.max(filtration_bounds, axis=(0,1))])
			num_parameters = box.shape[1]
			# A single int resolution applies to every parameter.
			if isinstance(self.resolution, (int, np.integer)):
				resolution = [self.resolution] * num_parameters
			else:
				resolution = self.resolution
			self.filtration_grid = [np.linspace(box[0, i], box[1, i], num=resolution[i]) for i in range(num_parameters)]
			print("Done.")
		return self
	def transform1(self, simplextree):
		"""Signed measures of one simplextree (or delayed call), one per degree of ``self.degrees``."""
		st = get_simplextree(simplextree)
		st = mp.SimplexTreeMulti(st, num_parameters = st.num_parameters) ## COPY: grid_squeeze modifies it
		st.grid_squeeze(filtration_grid = self.filtration_grid, coordinate_values = True)
		if self.num_collapses == "full":
			st.collapse_edges(full=True,max_dimension=1)
		elif isinstance(self.num_collapses, int):
			st.collapse_edges(num=self.num_collapses,max_dimension=1)
		else:
			raise Exception("Bad edge collapse type. either 'full' or an int.")
		signed_measures = []
		# higher simplices are needed up to dimension max(degrees)+1
		st.expansion(np.max(self.degrees)+1)
		for degree in self.degrees:
			hilbert_function = mp.hilbert2d(simplextree=st, degree=degree, grid_shape=(len(self.filtration_grid[0]), len(self.filtration_grid[1]))) # TODO : nd ?
			signed_measure = signed_betti(hilbert_function, threshold=True, sparse=False) # TODO : sad to deal with enormous matrices, but gaussian convolutions is way faster like this 
			if self.sparse: 
				signed_measure = tensor_möbius_inversion(tensor = signed_measure, 
				grid_conversion=self.filtration_grid, plot = self.plot, num_parameters=len(self.filtration_grid), )
			signed_measures.append(signed_measure)
		return signed_measures
	def transform(self,X):
		"""List, over ``X``, of the lists of signed measures returned by ``transform1``."""
		assert self.filtration_grid is not None
		return Parallel(n_jobs=self.n_jobs)(
			delayed(self.transform1)(to_st) for to_st in tqdm(X, disable = not self.progress, desc="Computing Hilbert function")
		)


class SignedMeasure2Img(BaseEstimator,TransformerMixin):
	"""
	Transformer. Signed measures -> images, by kernel convolution.

	Each element of the input is a list of signed measures (e.g. one per
	homological degree), each of which is either
	- dense (an array on a grid): smoothed with ``gaussian_filter(sigma=bandwidth)``;
	- sparse (a ``(points, weights)`` tuple): convolved with a ``KernelDensity``
	  kernel and evaluated on ``filtration_grid``.
	The images of one element are concatenated along axis 0.

	Parameters
	----------
	filtration_grid : evaluation grid for sparse inputs; inferred in ``fit`` from
		all the points (``resolution`` + ``grid_strategy``, see
		``infer_grid_from_points``) when None.
	kernel, bandwidth : KernelDensity kernel/bandwidth (sparse) or gaussian sigma (dense).
	flatten : flatten each output image.
	n_jobs : joblib parallelism.
	resolution : number of grid values per coordinate when inferring the grid.
	grid_strategy : "regular" or "quantile".
	sparse : True / False to declare the input type, None to infer it in ``fit``.
	**kwargs : forwarded to ``gaussian_filter`` (dense) or ``KernelDensity`` (sparse).
	"""
	def __init__(self, filtration_grid:Iterable[np.ndarray]=None, kernel="gaussian", bandwidth=1., flatten:bool=False, n_jobs:int=1,resolution:int=None, grid_strategy:str="regular",sparse:bool|None = None, **kwargs):
		super().__init__()
		self.kernel=kernel
		self.bandwidth=bandwidth
		self.more_kde_kwargs=kwargs
		self.filtration_grid=filtration_grid
		self.flatten=flatten
		self.n_jobs = n_jobs
		self.resolution = resolution
		self.grid_strategy = grid_strategy
		self.sparse=sparse # input is either sparse or not; None -> inferred in fit.
		return
	@staticmethod
	def _is_sparse_input(X)->bool:
		"""True if the signed measures of X are ``(points, weights)`` tuples."""
		first = X[0]
		if isinstance(first, tuple):
			return True
		# Usually X[i] is a list with one signed measure per homological degree.
		return len(first) > 0 and isinstance(first[0], tuple)
	def fit(self, X, y=None):
		## Infers if the input is sparse given X 
		if len(X) == 0: return self
		if self.sparse is None:
			self.sparse = self._is_sparse_input(X)
		if not self.sparse:	return self # in that case, signed measures are matrices, and the grid is already given
		
		## If sparse : a grid has to be defined
		if self.filtration_grid is None and self.resolution is None:
			raise Exception("Cannot infer filtration grid. Provide either a filtration grid or a resolution.")
		if self.filtration_grid is None:
			pts = np.concatenate([
				sm[0] for signed_measures in X for sm in signed_measures
			])
			self.filtration_grid = infer_grid_from_points(pts, strategy=self.grid_strategy, num=self.resolution)
		
		return self
	def _sm2smi(self, signed_measures:Iterable[np.ndarray]):
		"""Images of dense signed measures, concatenated along axis 0."""
		return np.concatenate([
				gaussian_filter(signed_measure, sigma=self.bandwidth, **self.more_kde_kwargs)
			for signed_measure in signed_measures], axis=0)
	def _sm2smi_sparse(self, signed_measures:Iterable[np.ndarray]):
		"""Images of ``(points, weights)`` signed measures on ``filtration_grid``, concatenated along axis 0."""
		return np.concatenate([
				_pts_convolution_sparse(
					pts = signed_measure_pts, pts_weights = signed_measure_weights,
					filtration_grid = self.filtration_grid, 
					kernel=self.kernel,
					bandwidth=self.bandwidth,
					**self.more_kde_kwargs
				)
			for signed_measure_pts, signed_measure_weights  in signed_measures], axis=0)
	def transform(self,X):
		if self.sparse is None :	raise Exception("Fit first")
		todo = SignedMeasure2Img._sm2smi_sparse if self.sparse else  SignedMeasure2Img._sm2smi
		out =  Parallel(n_jobs=self.n_jobs)(delayed(todo)(self, signed_measures) for signed_measures in X)
		if self.flatten:
			return [x.flatten() for x in out]
		return out



class SignedMeasure2SlicedWassersteinDistance(BaseEstimator,TransformerMixin):
	"""
	Placeholder transformer, not implemented yet.

	``fit`` does nothing and ``transform`` currently returns None.
	"""
	def __init__(self):
		super().__init__()

	def fit(self, X, y=None):
		"""No-op; returns self."""
		return self

	def transform(self,X):
		"""TODO (per the class name): sliced Wasserstein distances between signed measures. Returns None for now."""
		return None


def _print_best_params(cl)->None:
	"""Prints ``cl.best_params_`` when the classifier exposes it (e.g. a fitted grid search); silent otherwise."""
	try:
		best_params = cl.best_params_
	except AttributeError:
		return
	print("Best classification parameters : ", best_params)

def accuracy_to_csv(X,Y,cl, cln:str, k:float=10, dataset:str = "", filtration:str = "", shuffle=True,  verbose:bool=True, **kwargs):
	"""
	Evaluates the classifier ``cl`` on (X, Y) and appends a result row to ``result_{dataset}.csv``.

	Parameters
	----------
	X, Y : samples and labels (indexable).
	cl : classifier with ``fit`` / ``score``; its ``best_params_`` are printed if present.
	cln : pipeline name, stored in the "pipeline" column.
	k : if > 1, number of stratified folds (cast to int); otherwise (0 < k <= 1)
		forwarded as ``test_size`` of a single train/test split.
	dataset, filtration : stored in their columns; ``dataset`` also names the
		file ("/" replaced by "_", ".off" removed).
	shuffle : forwarded to the splitter.
	verbose : print the accuracies.
	**kwargs : extra columns of the row. Keys colliding with the standard
		columns are dropped with a warning.
	"""
	assert k > 0, "k is either the number of kfold > 1 or the test size > 0."
	if k>1:
		k = int(k)
		from sklearn.model_selection import StratifiedKFold as KFold
		kfold = KFold(k, shuffle=shuffle).split(X,Y)
		accuracies = np.zeros(k)
		for fold, (train_idx, test_idx) in enumerate(tqdm(kfold, total=k, desc="Computing kfold")):
			xtrain = [X[i] for i in train_idx]
			ytrain = [Y[i] for i in train_idx]
			cl.fit(xtrain, ytrain)
			xtest = [X[i] for i in test_idx]
			ytest = [Y[i] for i in test_idx] 
			accuracies[fold] = cl.score(xtest, ytest)
			if verbose:
				print(f"step {fold+1}, {dataset} : {accuracies[fold]}", flush=True)
				_print_best_params(cl)
	elif k > 0:
		from sklearn.model_selection import train_test_split
		print("Computing accuracy, with train test split", flush=True)
		xtrain, xtest, ytrain, ytest = train_test_split(X, Y, shuffle=shuffle, test_size=k)
		print("Fitting...", end="", flush=True)
		cl.fit(xtrain, ytrain)
		print("Computing score...", end="", flush=True)
		accuracies = cl.score(xtest, ytest)
		_print_best_params(cl)
		print("Done.")
		if verbose:	print(f"Accuracy {dataset} : {accuracies} ")
	file_path:str = f"result_{dataset}.csv".replace("/", "_").replace(".off", "")
	columns:list[str] = ["dataset", "filtration", "pipeline", "cv", "mean", "std"]
	standard_row = dict(zip(columns, [dataset, filtration, cln, k, np.mean(accuracies), np.std(accuracies)]))
	more_names = []
	more_values = []
	for key, value in kwargs.items():
		if key in standard_row:
			# report the value actually stored in that column, not always the pipeline name
			warn(f"Duplicate key {key} ! with values {standard_row[key]} and {value}")
		else:
			more_names.append(key)
			more_values.append(value)
	new_line:pd.DataFrame = pd.DataFrame([list(standard_row.values())+more_values], columns = columns+more_names)
	# Concatenating with an empty frame is deprecated in pandas; only concat with existing results.
	df:pd.DataFrame = pd.concat([pd.read_csv(file_path), new_line]) if exists(file_path) else new_line
	df.to_csv(file_path, index=False)
